//===--- ThunkLowering.cpp ------------------------------------------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2024 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

#include "swift/AST/ASTMangler.h"
#include "swift/Basic/Defer.h"
#include "swift/SIL/SILBuilder.h"

#include "swift/SILOptimizer/PassManager/Passes.h"
#include "swift/SILOptimizer/PassManager/Transforms.h"
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"

using namespace swift;

//===----------------------------------------------------------------------===//
//                              MARK: Utilities
//===----------------------------------------------------------------------===//

/// Compute the lowered type of the thunk function generated for a `thunk`
/// instruction of the given \p kind wrapping a value of \p inputFunctionType.
///
/// For every valid kind the thunk is a thin function whose parameters are the
/// input function's parameters followed by the input function value itself
/// (passed @guaranteed). Results, yields, errors, coroutine kind, async-ness,
/// generic signature and substitutions are all copied from the input type.
/// \p mod is currently unused.
///
/// Returns a null type for ThunkInst::Kind::Invalid.
static CanSILFunctionType
getThunkFunctionType(ThunkInst::Kind kind, CanSILFunctionType inputFunctionType,
                     SILModule &mod) {
  switch (kind) {
  case ThunkInst::Kind::Invalid:
    return CanSILFunctionType();

  // Both thunk kinds currently share exactly the same signature shape.
  case ThunkInst::Kind::Identity:
  case ThunkInst::Kind::HopToMainActorIfNeeded: {
    auto originalParams = inputFunctionType->getParameters();
    llvm::SmallVector<SILParameterInfo, 8> thunkParams(originalParams.begin(),
                                                       originalParams.end());
    // The function being thunked is passed as the trailing parameter.
    thunkParams.push_back(SILParameterInfo(
        inputFunctionType, ParameterConvention::Direct_Guaranteed));

    auto thunkExtInfo =
        SILExtInfoBuilder()
            .withRepresentation(SILFunctionTypeRepresentation::Thin)
            .withAsync(inputFunctionType->isAsync())
            .build();

    return SILFunctionType::get(
        inputFunctionType->getInvocationGenericSignature(), thunkExtInfo,
        inputFunctionType->getCoroutineKind(),
        ParameterConvention::Direct_Unowned, thunkParams,
        inputFunctionType->getYields(), inputFunctionType->getResults(),
        inputFunctionType->getOptionalErrorResult(),
        inputFunctionType->getPatternSubstitutions(),
        inputFunctionType->getInvocationSubstitutions(),
        inputFunctionType->getASTContext());
  }
  }

  llvm_unreachable("Covered switch isn't covered?!");
}

namespace {

/// Helper that emits the body of a thunk function.
///
/// The expected sequence is createEntryBlockArguments() followed by
/// callThunkedFunction(); generate() does exactly that for the default
/// "forward all arguments" thunk.
struct ThunkBodyBuilder {
  /// Builder whose insertion point is the thunk's entry block (created by the
  /// constructor).
  SILBuilder builder;
  /// The thunk function whose body is being emitted.
  SILFunction *thunk;
  /// Entry-block arguments created by createEntryBlockArguments(): indirect
  /// results first, then the thunk's parameters in declaration order.
  llvm::SmallVector<SILValue, 8> thunkArguments;

  /// Eagerly creates the thunk's first (entry) block so that `builder` has an
  /// insertion point from the start.
  ThunkBodyBuilder(SILFunction *thunk)
      : builder(thunk->createBasicBlock()), thunk(thunk) {}

  /// Default implementation: create the entry arguments and call the trailing
  /// argument (the thunked function value) with all preceding arguments.
  /// Intended as an example for more complex generators that compose with
  /// ThunkBodyBuilder. Note: this is not virtual; derived builders shadow it.
  void generate();

  /// Return a location for thunk instructions. It is auto-generated so it does
  /// not show up in debug info; this helper just makes it quicker to type.
  SILLocation getLoc() const {
    return RegularLocation::getAutoGeneratedLocation();
  }

  /// The lowered function type of the thunk being built.
  CanSILFunctionType getThunkType() const {
    return thunk->getLoweredFunctionType();
  }

  /// SIL calling conventions of the thunk, derived from getThunkType().
  SILFunctionConventions getThunkConventions() const {
    return SILFunctionConventions(getThunkType(), thunk->getModule());
  }

  /// Create the entry block arguments for the function, appending them to
  /// thunkArguments. The entry block must have no arguments yet.
  void createEntryBlockArguments();

  /// Create a call to the thunked function and terminate the body. Dispatches
  /// to begin_apply (yield-once coroutines), try_apply (callees with an error
  /// result) or a plain apply.
  void callThunkedFunction(SILValue function, ArrayRef<SILValue> arguments);

private:
  /// apply + return.
  void callApplyThunkedFunction(SILValue function,
                                ArrayRef<SILValue> arguments);
  /// try_apply with a normal block that returns and an error block that
  /// rethrows.
  void callTryApplyThunkedFunction(SILValue function,
                                   ArrayRef<SILValue> arguments);
  /// begin_apply + yield, with resume (end_apply/return) and unwind
  /// (abort_apply/unwind) blocks.
  void callBeginApplyThunkedFunction(SILValue function,
                                     ArrayRef<SILValue> arguments);
};

} // namespace

void ThunkBodyBuilder::createEntryBlockArguments() {
  auto conventions = getThunkConventions();
  auto *block = &thunk->front();
  assert(block->getNumArguments() == 0 && "entry should be uninitialized");

  // First add our indirect results.
  for (auto indirectResult : conventions.getIndirectSILResults()) {
    SILType ty = conventions.getSILType(indirectResult,
                                        thunk->getTypeExpansionContext());
    ty = thunk->mapTypeIntoContext(ty);
    thunkArguments.push_back(block->createFunctionArgument(ty));
  }

  // Then add our normal parameters.
  for (auto paramInfo : conventions.getParameters()) {
    SILType ty =
        conventions.getSILType(paramInfo, thunk->getTypeExpansionContext());
    ty = thunk->mapTypeIntoContext(ty);
    thunkArguments.push_back(block->createFunctionArgument(ty));
  }
}

/// Emit `apply function(arguments)` using the thunk's forwarding
/// substitutions, then return the apply's result from the thunk.
void ThunkBodyBuilder::callApplyThunkedFunction(SILValue function,
                                                ArrayRef<SILValue> arguments) {
  auto loc = RegularLocation::getAutoGeneratedLocation();
  SubstitutionMap forwardingSubs = thunk->getForwardingSubstitutionMap();

  ApplyInst *call = builder.createApply(loc, function, forwardingSubs, arguments);
  builder.createReturn(loc, call);
}

/// Emit a yield-once coroutine call to \p function from the thunk.
///
/// Shape of the emitted code:
///   entry:  %token = begin_apply function(arguments)
///           yield <callee's yielded values>, resume bbResume, unwind bbUnwind
///   bbResume: %r = end_apply %token   (direct results, () or a tuple)
///             return %r
///   bbUnwind: abort_apply %token
///             unwind
void ThunkBodyBuilder::callBeginApplyThunkedFunction(
    SILValue function, ArrayRef<SILValue> arguments) {
  // Conventions of the *callee*; used to compute the end_apply result type.
  auto calleeType = function->getType().castTo<SILFunctionType>();
  auto conventions = SILFunctionConventions(calleeType, *function->getModule());
  auto *ai = builder.createBeginApply(
      RegularLocation::getAutoGeneratedLocation(), function,
      thunk->getForwardingSubstitutionMap(), arguments);

  // Resume block: finish the coroutine and return its direct results.
  auto *resumeBlock = thunk->createBasicBlock();
  {
    SILBuilder resumeBlockBuilder(resumeBlock);
    llvm::SmallVector<TupleTypeElt> directResultTypes;

    for (auto result : conventions.getDirectSILResults()) {
      auto ty =
          conventions.getSILType(result, thunk->getTypeExpansionContext());
      ty = thunk->mapTypeIntoContext(ty);
      directResultTypes.push_back(ty.getASTType());
    }

    // Zero direct results -> (), one -> that type, several -> a tuple.
    SILType resultType = SILType::getEmptyTupleType(thunk->getASTContext());
    if (directResultTypes.size() == 1) {
      resultType = SILType::getPrimitiveObjectType(
          directResultTypes.front().getType()->getCanonicalType());
    } else if (directResultTypes.size() > 1) {
      auto tupleTy = TupleType::get(directResultTypes, thunk->getASTContext())
                         ->getCanonicalType();
      resultType = SILType::getPrimitiveObjectType(tupleTy);
    }

    SILValue result = resumeBlockBuilder.createEndApply(
        getLoc(), ai->getTokenResult(), resultType);
    resumeBlockBuilder.createReturn(getLoc(), result);
  }

  // Unwind block: abort the inner coroutine and unwind the thunk.
  auto *unwindBlock = thunk->createBasicBlock();
  {
    SILBuilder unwindBlockBuilder(unwindBlock);
    unwindBlockBuilder.createAbortApply(getLoc(), ai->getTokenResult());
    unwindBlockBuilder.createUnwind(getLoc());
  }

  // Finally, terminate the entry block by re-yielding every value the callee
  // yielded, resuming into resumeBlock or unwinding into unwindBlock.
  llvm::SmallVector<SILValue, 8> yieldedValues;
  copy(ai->getYieldedValues(), std::back_inserter(yieldedValues));
  builder.createYield(RegularLocation::getAutoGeneratedLocation(),
                      yieldedValues, resumeBlock, unwindBlock);
}

/// Emit a throwing call to \p function from the thunk:
///
///   entry:     try_apply function(arguments), normal bbNormal, error bbError
///   bbNormal(%r):  return %r     (%r: (), the single direct result, or a tuple)
///   bbError(%e):   throw %e
///
/// The normal block is created before the error block.
void ThunkBodyBuilder::callTryApplyThunkedFunction(
    SILValue function, ArrayRef<SILValue> arguments) {
  ASTContext &ctx = thunk->getASTContext();
  auto calleeFnTy = function->getType().castTo<SILFunctionType>();
  SILFunctionConventions calleeConv(calleeFnTy, *function->getModule());

  SILBasicBlock *successBB = thunk->createBasicBlock();

  // Gather the callee's direct results, mapped into the thunk's context.
  llvm::SmallVector<TupleTypeElt, 8> directResultElts;
  for (auto resultInfo : calleeConv.getDirectSILResults()) {
    SILType resultTy = thunk->mapTypeIntoContext(
        calleeConv.getSILType(resultInfo, thunk->getTypeExpansionContext()));
    directResultElts.push_back(resultTy.getASTType());
  }

  // The normal successor of try_apply takes exactly one argument, so fold the
  // direct results into (), the lone result type, or a tuple.
  SILType successArgTy;
  if (directResultElts.empty()) {
    successArgTy = SILType::getEmptyTupleType(ctx);
  } else if (directResultElts.size() == 1) {
    successArgTy = SILType::getPrimitiveObjectType(
        directResultElts.front().getType()->getCanonicalType());
  } else {
    successArgTy = SILType::getPrimitiveObjectType(
        TupleType::get(directResultElts, ctx)->getCanonicalType());
  }

  SILValue successArg = successBB->createPhiArgument(
      successArgTy, directResultElts.empty() ? OwnershipKind::None
                                             : OwnershipKind::Owned);
  SILBuilder(successBB).createReturn(getLoc(), successArg);

  // NOTE(review): the error argument is always typed as the generic exception
  // type; confirm this is intended for callees using typed throws.
  SILBasicBlock *failureBB = thunk->createBasicBlock();
  SILValue thrownError = failureBB->createPhiArgument(
      SILType::getExceptionType(ctx), OwnershipKind::Owned);
  SILBuilder(failureBB).createThrow(getLoc(), thrownError);

  builder.createTryApply(RegularLocation::getAutoGeneratedLocation(), function,
                         thunk->getForwardingSubstitutionMap(), arguments,
                         successBB, failureBB);
}

/// Emit the call to the thunked function and the terminator(s) of the thunk.
/// Yield-once coroutines take priority, then throwing functions; everything
/// else is a plain apply + return.
void ThunkBodyBuilder::callThunkedFunction(SILValue function,
                                           ArrayRef<SILValue> arguments) {
  auto fnTy = function->getType().castTo<SILFunctionType>();

  if (fnTy->getCoroutineKind() == SILCoroutineKind::YieldOnce)
    callBeginApplyThunkedFunction(function, arguments);
  else if (fnTy->hasErrorResult())
    callTryApplyThunkedFunction(function, arguments);
  else
    callApplyThunkedFunction(function, arguments);
}

void ThunkBodyBuilder::generate() {
  createEntryBlockArguments();
  callThunkedFunction(thunkArguments.back(),
                      ArrayRef<SILValue>(thunkArguments).drop_back());
}

//===----------------------------------------------------------------------===//
//                               MARK: Identity
//===----------------------------------------------------------------------===//

namespace {

/// Lowers a single `thunk` instruction of kind Identity into a
/// function_ref + partial_apply of a generated forwarding thunk.
///
/// This is a one-shot object: lower() must be invoked on an rvalue, and it
/// erases the instruction (via invalidate()). The destructor asserts that
/// this happened.
struct IdentityLowering {
  SILOptFunctionBuilder &funcBuilder;
  ThunkInst *ti;

  /// Reference to a counter intended to give multiple thunks in one function
  /// unique names while prototyping. Not read by this lowering.
  unsigned &thunkCount;

  IdentityLowering(SILOptFunctionBuilder &funcBuilder, ThunkInst *ti,
                   unsigned &thunkCount)
      : funcBuilder(funcBuilder), ti(ti), thunkCount(thunkCount) {}

  ~IdentityLowering() {
    assert(!ti && "Did not call consuming method to destroy value");
  }

  /// Replace `ti` with the lowered form and erase it.
  void lower() &&;

  /// Find or create (and, if new, emit the body of) the thunk function.
  SILFunction *createThunk() const;

  /// Erase the thunk instruction and mark this lowering as consumed.
  void invalidate() {
    ti->eraseFromParent();
    ti = nullptr;
  }
};

} // namespace

void IdentityLowering::lower() && {
  SWIFT_DEFER { invalidate(); };

  // Create the thunk.
  auto *thunk = createThunk();

  SILBuilderWithScope builder(ti);
  SingleValueInstruction *thunkValue =
      builder.createFunctionRef(ti->getLoc(), thunk);

  thunkValue = builder.createPartialApply(
      ti->getLoc(), thunkValue, ti->getSubstitutionMap(), ti->getOperand(),
      ParameterConvention::Direct_Guaranteed);

  ti->replaceAllUsesWith(thunkValue);
}

/// Return the shared Identity thunk for the function containing `ti`. The
/// thunk is created if needed, and its body is emitted the first time only.
///
/// The thunk's type is computed by getThunkFunctionType(): the operand's
/// parameters plus the operand function value itself, with the operand's
/// results. A newly created thunk:
///   - has ownership eliminated (bodies are emitted in non-OSSA form),
///   - shares the generic environment of the parent function,
///   - is moved directly before the parent function in the module.
SILFunction *IdentityLowering::createThunk() const {
  // The operand of a thunk instruction must be a function value. castTo
  // asserts this; the previous getAs silently yielded a null type on
  // mismatch.
  auto inputFuncType = ti->getOperand()->getType().castTo<SILFunctionType>();

  auto thunkType =
      getThunkFunctionType(ti->getThunkKind(), inputFuncType, ti->getModule());

  // NOTE(review): the name depends only on the parent function's name and the
  // thunk kind. Two Identity thunks in one function over different operand
  // types would therefore map to the same name. Confirm this cannot happen;
  // `thunkCount` was intended for disambiguation but is unused.
  Mangle::ASTMangler mangler(ti->getModule().getASTContext());
  auto name = mangler.mangleSILThunkKind(ti->getFunction()->getName(),
                                         ThunkInst::Kind::Identity);

  auto *fn = funcBuilder.getOrCreateSharedFunction(
      RegularLocation::getAutoGeneratedLocation(), name, thunkType,
      IsBare_t::IsNotBare, IsTransparent_t::IsNotTransparent,
      SerializedKind_t::IsNotSerialized, ProfileCounter(), IsThunk_t::IsThunk,
      IsDynamicallyReplaceable_t::IsNotDynamic,
      IsDistributed_t::IsNotDistributed,
      IsRuntimeAccessible_t::IsNotRuntimeAccessible);

  // A non-empty body means an earlier lowering already emitted this thunk.
  if (!fn->empty())
    return fn;

  // These are only generated when not in Ownership SSA. Turn off Ownership SSA
  // so that SILBuilder and other utilities do the right thing and so we can
  // avoid having to run ownership lowering.
  fn->setOwnershipEliminated();

  // Use the same generic environment as the original function.
  fn->setGenericEnvironment(ti->getFunction()->getGenericEnvironment());

  // Move the thunk right before the original function to ease FileCheck.
  fn->getModule().moveBefore(ti->getFunction()->getIterator(), fn);

  // Generate the (forwarding) body of the function.
  ThunkBodyBuilder thunkBodyBuilder(fn);
  thunkBodyBuilder.generate();

  return fn;
}

//===----------------------------------------------------------------------===//
//                     MARK: Hop To Main Actor If Needed
//===----------------------------------------------------------------------===//

namespace {

struct HopToMainActorIfNeededThunkBodyBuilder : ThunkBodyBuilder {
  HopToMainActorIfNeededThunkBodyBuilder(SILFunction *fn)
      : ThunkBodyBuilder(fn) {}

  void generate();

  SILValue adjustCalleeType(FunctionRefInst *runOnMainActor, SILValue callee);
};

} // namespace

/// Return \p originalCallee unchanged unless its representation is thin. A
/// thin callee is wrapped in a thin_to_thick_function to the equivalent thick
/// type. \p runOnMainActor is currently not consulted.
SILValue HopToMainActorIfNeededThunkBodyBuilder::adjustCalleeType(
    FunctionRefInst *runOnMainActor, SILValue originalCallee) {
  auto originalType = originalCallee->getType().castTo<SILFunctionType>();

  if (originalType->getRepresentation() != SILFunctionTypeRepresentation::Thin)
    return originalCallee;

  auto thickExtInfo =
      originalType->getExtInfo()
          .intoBuilder()
          .withRepresentation(SILFunctionTypeRepresentation::Thick)
          .build();
  auto thickType = originalType->getWithExtInfo(thickExtInfo);

  return builder.createThinToThickFunction(
      getLoc(), originalCallee, SILType::getPrimitiveObjectType(thickType));
}

void HopToMainActorIfNeededThunkBodyBuilder::generate() {
  createEntryBlockArguments();

  assert(thunkArguments.size() == 1 && "We only support no thunk arguments");

  // Create a function ref for _runOnMainActor. We have to use link all to
  // ensure that we link in the closures we create so that hop to main executor
  // lowering runs on them.
  StringLiteral value = "$ss19_taskRunOnMainActor9operationyyyScMYcc_tF";
  auto *function =
      builder.getModule().loadFunction(value, SILModule::LinkingMode::LinkAll);
  assert(function && "Cannot find runOnMainActor");

  // Create the function ref.
  auto *fri = builder.createFunctionRef(getLoc(), function);

  callThunkedFunction(fri, {adjustCalleeType(fri, thunkArguments.back())});
}

namespace {

/// Lowers a single `thunk` instruction of kind HopToMainActorIfNeeded into a
/// function_ref + partial_apply of a generated thunk that runs its operand
/// via `_taskRunOnMainActor`.
///
/// One-shot: lower() must be invoked on an rvalue and erases the instruction
/// (via invalidate()). The destructor asserts that this happened.
struct HopToMainActorIfNeededLowering {
  SILOptFunctionBuilder &funcBuilder;
  ThunkInst *ti;

  /// Reference to a counter intended to give multiple thunks in one function
  /// unique names while prototyping. Not read by this lowering.
  unsigned &thunkCount;

  HopToMainActorIfNeededLowering(SILOptFunctionBuilder &funcBuilder,
                                 ThunkInst *ti, unsigned &thunkCount)
      : funcBuilder(funcBuilder), ti(ti), thunkCount(thunkCount) {}

  ~HopToMainActorIfNeededLowering() {
    assert(!ti && "Did not call consuming method to destroy value");
  }

  /// Replace `ti` with the lowered form and erase it.
  void lower() &&;

  /// Find or create (and, if new, emit the body of) the thunk function.
  SILFunction *createThunk() const;

  /// Erase the thunk instruction and mark this lowering as consumed.
  void invalidate() {
    ti->eraseFromParent();
    ti = nullptr;
  }
};

} // namespace

void HopToMainActorIfNeededLowering::lower() && {
  SWIFT_DEFER { invalidate(); };

  // Create the thunk.
  auto *thunk = createThunk();

  SILBuilderWithScope builder(ti);
  SingleValueInstruction *thunkValue =
      builder.createFunctionRef(ti->getLoc(), thunk);

  thunkValue = builder.createPartialApply(
      ti->getLoc(), thunkValue, ti->getSubstitutionMap(), ti->getOperand(),
      ParameterConvention::Direct_Guaranteed);

  ti->replaceAllUsesWith(thunkValue);
}

/// Return the shared HopToMainActorIfNeeded thunk for the function
/// containing `ti`. The thunk is created if needed, and its body is emitted
/// the first time only.
///
/// A newly created thunk:
///   - has ownership eliminated (bodies are emitted in non-OSSA form),
///   - shares the generic environment of the parent function,
///   - is moved directly before the parent function in the module,
///   - gets its body from HopToMainActorIfNeededThunkBodyBuilder.
SILFunction *HopToMainActorIfNeededLowering::createThunk() const {
  // The operand of a thunk instruction must be a function value. castTo
  // asserts this; the previous getAs silently yielded a null type on
  // mismatch.
  auto inputFuncType = ti->getOperand()->getType().castTo<SILFunctionType>();

  auto thunkType =
      getThunkFunctionType(ti->getThunkKind(), inputFuncType, ti->getModule());

  // NOTE(review): the name depends only on the parent function's name and the
  // thunk kind. Two such thunks in one function over different operand types
  // would map to the same name. Confirm this cannot happen.
  Mangle::ASTMangler mangler(ti->getModule().getASTContext());
  auto name = mangler.mangleSILThunkKind(
      ti->getFunction()->getName(), ThunkInst::Kind::HopToMainActorIfNeeded);

  auto *fn = funcBuilder.getOrCreateSharedFunction(
      RegularLocation::getAutoGeneratedLocation(), name, thunkType,
      IsBare_t::IsNotBare, IsTransparent_t::IsNotTransparent,
      SerializedKind_t::IsNotSerialized, ProfileCounter(), IsThunk_t::IsThunk,
      IsDynamicallyReplaceable_t::IsNotDynamic,
      IsDistributed_t::IsNotDistributed,
      IsRuntimeAccessible_t::IsNotRuntimeAccessible);

  // If we already have a body, we already generated a thunk for this
  // function. Just return it.
  if (!fn->empty())
    return fn;

  // These are only generated when not in Ownership SSA. Turn off Ownership SSA
  // so that SILBuilder and other utilities do the right thing and so we can
  // avoid having to run ownership lowering.
  fn->setOwnershipEliminated();

  // Use the same generic environment as the original function.
  fn->setGenericEnvironment(ti->getFunction()->getGenericEnvironment());

  // Move the thunk right before the original function to ease FileCheck.
  fn->getModule().moveBefore(ti->getFunction()->getIterator(), fn);

  // Generate the body of the function.
  HopToMainActorIfNeededThunkBodyBuilder thunkBodyBuilder(fn);
  thunkBodyBuilder.generate();

  return fn;
}

//===----------------------------------------------------------------------===//
//                         MARK: Top Level Entrypoint
//===----------------------------------------------------------------------===//

namespace {

/// Module pass that lowers every `thunk` instruction. Each one becomes a
/// function_ref + partial_apply of a generated thunk function; the original
/// instruction is erased.
class ThunkLoweringPass : public SILModuleTransform {
  void run() override {
    auto *mod = getModule();

    SILOptFunctionBuilder funcBuilder(*this);

    // Set once any thunk instruction was lowered, so analyses get invalidated.
    bool createdThunk = false;

    // Iterators are advanced before the current element is processed.
    // Lowering erases the thunk instruction and inserts new instructions and
    // functions, but always before the current position: new thunk functions
    // are moved before the function being processed.
    //
    // It is assumed in this code that the only instruction we delete is the
    // thunk itself. We leave cleaning everything else up to other passes just
    // to make the invalidation rules in this pass simple.
    for (auto fi = mod->begin(), fe = mod->end(); fi != fe;) {
      auto *fn = &*fi;
      ++fi;

      // A per-function thunk count that can be used to quickly generate unique
      // per-function names for thunks by appending the name of the thunk and a
      // count to the function name. Reset for every function so it matches
      // that per-function intent (it used to be shared across the module).
      // NOTE(review): neither lowering currently increments or reads it.
      unsigned thunkCount = 0;

      for (auto &block : *fn) {
        for (auto ii = block.begin(), ie = block.end(); ii != ie;) {
          auto *ti = dyn_cast<ThunkInst>(&*ii);
          ++ii;

          if (!ti)
            continue;

          switch (ti->getThunkKind()) {
          case ThunkInst::Kind::Invalid:
            llvm_unreachable("Should never see an invalid kind");
          case ThunkInst::Kind::Identity: {
            IdentityLowering lowering(funcBuilder, ti, thunkCount);
            std::move(lowering).lower();
            createdThunk = true;
            continue;
          }
          case ThunkInst::Kind::HopToMainActorIfNeeded: {
            HopToMainActorIfNeededLowering lowering(funcBuilder, ti,
                                                    thunkCount);
            std::move(lowering).lower();
            createdThunk = true;
            continue;
          }
          }

          llvm_unreachable("Covered switch isn't covered?!");
        }
      }
    }

    if (createdThunk)
      invalidateAll();
  }
};

} // namespace

/// Factory used by the pass manager to instantiate the thunk lowering pass.
SILTransform *swift::createThunkLowering() {
  auto *pass = new ThunkLoweringPass();
  return pass;
}
